
Conversation

@ADream-ki
Contributor

PR types

Others

PR changes

Others

Describe

Add the CoNFILD example.

@paddle-bot

paddle-bot bot commented Oct 19, 2025

Thanks for your contribution!

ADream-ki and others added 7 commits November 2, 2025 11:00
…andling and loss logging. Update _update_ema function to work with parameter dictionaries. Enhance log_loss_dict to control loss aggregation during validation. Modify GaussianDiffusion class to ensure loss calculation includes valid_mse and vb terms correctly.
…dability and consistency. Update variable names for clarity, enhance gradient handling, and streamline loss logging. Fix tensor type conversions and ensure proper handling of model parameters during training and evaluation.
…irectory exists before saving model parameters. This change improves file management during training by organizing saved models into a specified output directory.
…for dynamic adjustment while defaulting to 15698. This enhances flexibility in model evaluation.
Comment on lines +23 to +25

```bash
git clone https://github.com/PaddlePaddle/PaddleScience.git
cd PaddleScience/examples/confild
python confild.py mode=train
```
Collaborator
Suggested change

```diff
-git clone https://github.com/PaddlePaddle/PaddleScience.git
-cd PaddleScience/examples/confild
-python confild.py mode=train
+python confild.py mode=train
```
Comment on lines +86 to +178
```python
import paddle


class Normalizer_ts(object):
def __init__(self, params=[], method="-11", dim=None):
self.params = params
self.method = method
self.dim = dim

def fit_normalize(self, data):
        assert isinstance(data, paddle.Tensor)
if len(self.params) == 0:
if self.method == "-11" or self.method == "01":
                if self.dim is None:
                    self.params = paddle.max(x=data), paddle.min(x=data)
                else:
                    # keep only the max/min values; the argmax/argmin pairs were
                    # auto-conversion leftovers and are unnecessary
                    self.params = (
                        paddle.max(x=data, axis=self.dim, keepdim=True),
                        paddle.min(x=data, axis=self.dim, keepdim=True),
                    )
elif self.method == "ms":
if self.dim is None:
self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
x=data, axis=self.dim
)
else:
self.params = paddle.mean(
x=data, axis=self.dim, keepdim=True
), paddle.std(x=data, axis=self.dim, keepdim=True)
elif self.method == "none":
self.params = None
return self.fnormalize(data, self.params, self.method)

def normalize(self, new_data):
if not new_data.place == self.params[0].place:
self.params = self.params[0].to(new_data.place), self.params[1].to(
new_data.place
)
return self.fnormalize(new_data, self.params, self.method)

def denormalize(self, new_data_norm):
if not new_data_norm.place == self.params[0].place:
self.params = self.params[0].to(new_data_norm.place), self.params[1].to(
new_data_norm.place
)
return self.fdenormalize(new_data_norm, self.params, self.method)

def get_params(self):
if self.method == "ms":
print("returning mean and std")
elif self.method == "01":
print("returning max and min")
elif self.method == "-11":
print("returning max and min")
elif self.method == "none":
print("do nothing")
return self.params

@staticmethod
def fnormalize(data, params, method):
if method == "-11":
return (data - params[1].to(data.place)) / (
params[0].to(data.place) - params[1].to(data.place)
) * 2 - 1
elif method == "01":
return (data - params[1].to(data.place)) / (
params[0].to(data.place) - params[1].to(data.place)
)
elif method == "ms":
return (data - params[0].to(data.place)) / params[1].to(data.place)
elif method == "none":
return data

@staticmethod
def fdenormalize(data_norm, params, method):
if method == "-11":
return (data_norm + 1) / 2 * (
params[0].to(data_norm.place) - params[1].to(data_norm.place)
) + params[1].to(data_norm.place)
elif method == "01":
return data_norm * (
params[0].to(data_norm.place) - params[1].to(data_norm.place)
) + params[1].to(data_norm.place)
elif method == "ms":
return data_norm * params[1].to(data_norm.place) + params[0].to(
data_norm.place
)
elif method == "none":
return data_norm
```
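A quick round-trip check of the normalizer above (a minimal sketch; the tensor shape is illustrative):

```python
import paddle

# fit a mean/std normalizer over the batch axis, then invert it
norm = Normalizer_ts(method="ms", dim=0)
data = paddle.randn([100, 3])
normed = norm.fit_normalize(data)    # stores (mean, std) in norm.params
restored = norm.denormalize(normed)  # inverse transform
print(float(paddle.abs(restored - data).max()))  # ~0, up to float error
```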
Collaborator
Code snippets in the documentation can use the snippet-reference syntax to include the code directly from the .py file.
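For instance, with the pymdownx-snippets style used elsewhere in the PaddleScience docs (the path and line range below are illustrative):

```md
--8<--
examples/confild/confild.py:86:178
--8<--
```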

Comment on lines +184 to +290
```python
class SIRENAutodecoder_film(paddle.nn.Layer):
    """
    SIREN network with auto-decoding.
Args:
input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict.
output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict.
in_coord_features (int, optional): Number of input coordinates features
in_latent_features (int, optional): Number of input latent features
out_features (int, optional): Number of output features
num_hidden_layers (int, optional): Number of hidden layers
hidden_features (int, optional): Number of hidden features
outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False.
nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine".
weight_init (Callable, optional): Weight initialization function. Defaults to None.
bias_init (Callable, optional): Bias initialization function. Defaults to None.
premap_mode (str, optional): Feature mapping mode. Defaults to None.
Examples:
        >>> model = ppsci.arch.SIRENAutodecoder_film(
        ...     input_keys=("input1", "input2"),
        ...     output_keys=("output",),
        ...     in_coord_features=2,
        ...     in_latent_features=128,
        ...     out_features=3,
        ...     num_hidden_layers=10,
        ...     hidden_features=128,
        ... )
        >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])}
        >>> out_dict = model(input_data)
        >>> for k, v in out_dict.items():
        ...     print(k, v.shape)
        output [10, 3]
"""

def __init__(
self,
input_keys,
output_keys,
in_coord_features,
in_latent_features,
out_features,
num_hidden_layers,
hidden_features,
outermost_linear=False,
nonlinearity="sine",
weight_init=None,
bias_init=None,
premap_mode=None,
**kwargs,
):
super().__init__()
self.input_keys = input_keys
self.output_keys = output_keys

self.premap_mode = premap_mode
if self.premap_mode is not None:
self.premap_layer = FeatureMapping(
in_coord_features, mode=premap_mode, **kwargs
)
in_coord_features = self.premap_layer.dim
self.first_layer_init = None
self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
if weight_init is not None:
self.weight_init = weight_init
else:
self.weight_init = nl_weight_init
self.net1 = paddle.nn.LayerList(
sublayers=[BatchLinear(in_coord_features, hidden_features)]
+ [
BatchLinear(hidden_features, hidden_features)
for i in range(num_hidden_layers)
]
+ [BatchLinear(hidden_features, out_features)]
)
self.net2 = paddle.nn.LayerList(
sublayers=[
BatchLinear(in_latent_features, hidden_features, bias_attr=False)
for i in range(num_hidden_layers + 1)
]
)
if self.weight_init is not None:
self.net1.apply(self.weight_init)
self.net2.apply(self.weight_init)
if first_layer_init is not None:
self.net1[0].apply(first_layer_init)
self.net2[0].apply(first_layer_init)
if bias_init is not None:
self.net2.apply(bias_init)

def forward(self, input_data):
coords = input_data[self.input_keys[0]]
latents = input_data[self.input_keys[1]]
if self.premap_mode is not None:
x = self.premap_layer(coords)
else:
x = coords

for i in range(len(self.net1) - 1):
x = self.net1[i](x) + self.net2[i](latents)
x = self.nl(x)
x = self.net1[-1](x)
return {self.output_keys[0]: x}

    def disable_gradient(self):
        for param in self.parameters():
            param.stop_gradient = True
```

Collaborator

Same as above, use a snippet reference.
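For intuition, each hidden layer in the block above computes sine(W1·x + W2·z), i.e. the latent code additively modulates the coordinate features. A stripped-down sketch of one such layer (hypothetical names, SIREN-specific initialization omitted):

```python
import paddle

# one FiLM-style conditioned SIREN layer: sin(W1 @ x + W2 @ z)
coord_proj = paddle.nn.Linear(2, 128)                      # acts on coordinates
latent_proj = paddle.nn.Linear(128, 128, bias_attr=False)  # acts on the latent code

x = paddle.randn([10, 2])    # coordinates
z = paddle.randn([10, 128])  # latent code
h = paddle.sin(coord_proj(x) + latent_proj(z))
print(h.shape)  # [10, 128]
```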

4.3 Model training and evaluation
With the settings above in place, simply combine the instantiated objects as described in the documentation, then start training and evaluation.
```python
def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer):
```

Collaborator

Same as above, use a snippet reference.

## 5. Experimental Results

Collaborator

Where are the experimental results?

Comment on lines +331 to +332

```python
# display the figure
plt.show()
```
Collaborator
The .show call can be removed; saving the figure to a file is enough.

Suggested change

```diff
-# display the figure
-plt.show()
```
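If the figure should still be persisted, a save call can replace the interactive display (the output path is illustrative):

```python
import matplotlib.pyplot as plt

# save the current figure to disk instead of opening an interactive window
plt.savefig("confild_result.png", dpi=200, bbox_inches="tight")
plt.close()
```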

Collaborator

Please install pre-commit before committing, and run `pre-commit run` on each uploaded file to format it.
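For reference, the standard pre-commit workflow looks like this (the file path is illustrative):

```bash
pip install pre-commit
pre-commit install                                  # set up the git hook
pre-commit run --files examples/confild/confild.py  # format a single file
```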

"LatentNO",
"LatentNO_time",
"LNO",
"LossType",
Collaborator

XXType entries don't need to be part of the public interface and can be removed.
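That is, drop the enum-style types from `__all__` (a sketch based on the quoted diff):

```python
__all__ = [
    "LatentNO",
    "LatentNO_time",
    "LNO",
    # "LossType" removed: internal enum, not a public interface
]
```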

Comment on lines +143 to +182
```python
from collections import OrderedDict

import paddle


class BatchLinear(paddle.nn.Linear):
"""
Batch-wise linear transformation layer that supports manual parameter injection.
This layer extends paddle.nn.Linear to allow passing parameters explicitly,
which is useful for meta-learning and hypernetwork applications.
Args:
in_features (int): Size of input features.
out_features (int): Size of output features.
Note:
        - Weight shape: (in_features, out_features), matching Paddle's Linear convention and the matmul below.
        - Bias shape: (out_features,)
"""

def forward(self, input, params=None):
"""
Forward pass with optional external parameters.
Args:
input (paddle.Tensor): Input tensor of shape (..., in_features).
params (OrderedDict, optional): External parameters dict containing 'weight' and optionally 'bias'.
If None, uses internal parameters. Defaults to None.
Returns:
paddle.Tensor: Output tensor of shape (..., out_features).
"""
if params is None:
params = OrderedDict(self.named_parameters())
bias = params.get("bias", None)
weight = params["weight"]

output = paddle.matmul(x=input, y=weight)
if bias is not None:
output += bias.unsqueeze(axis=-2)
return output

```

Collaborator

What is the difference between BatchLinear and nn.Linear? If there is none, can it be removed?
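For context, the practical difference is the optional `params` argument, which lets a caller (e.g. a hypernetwork) inject weights at call time. A minimal sketch using the class above (shapes are illustrative):

```python
from collections import OrderedDict

import paddle

layer = BatchLinear(4, 2)
x = paddle.randn([8, 4])

# 1) with params=None it behaves like a plain nn.Linear
y1 = layer(x)

# 2) externally supplied weights, e.g. produced by a hypernetwork
ext = OrderedDict(
    weight=paddle.randn([4, 2]),  # Paddle Linear convention: (in, out)
    bias=paddle.randn([2]),
)
y2 = layer(x, params=ext)
print(y1.shape, y2.shape)  # [8, 2] [8, 2]
```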

Comment on lines +246 to +250
```python
self.centers = paddle.base.framework.EagerParamBase.from_tensor(
    tensor=paddle.empty(
        shape=(rbf_out_features, in_features), dtype="float32"
    )
)
```
Collaborator

Recent versions of Paddle support torch-like syntax directly: self.centers = nn.Parameter(...)
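A sketch of the suggested form (hypothetical minimal layer; assumes a Paddle release that exposes nn.Parameter, as the reviewer notes):

```python
import paddle
from paddle import nn

class RBFCenters(nn.Layer):
    def __init__(self, rbf_out_features: int, in_features: int):
        super().__init__()
        # torch-like parameter creation, per the review comment above
        self.centers = nn.Parameter(
            paddle.empty(shape=(rbf_out_features, in_features), dtype="float32")
        )

layer = RBFCenters(16, 3)
print(layer.centers.shape)  # [16, 3]
```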
